Transformers
The Transformer model, introduced by Vaswani et al. (2017), is a deep architecture based solely on attention mechanisms, dispensing with convolutional and recurrent layers. It was designed for sequence-to-sequence learning and has since been applied widely in language, vision, speech, and reinforcement learning. Because every position can attend to every other position directly, the architecture supports parallel computation across the sequence and keeps the maximum path length between any two positions short, making it both efficient to train and effective at modeling long-range dependencies.
Model Components
Multi-Head Self-Attention
- Queries, keys, and values are derived from the same input for self-attention in both the encoder and decoder, enhancing the model's ability to capture different aspects of information from the same input.
- The encoder uses self-attention for input representation, while the decoder uses masked self-attention to ensure outputs are based only on earlier timesteps, maintaining the autoregressive property.
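The masked self-attention in the second bullet can be sketched concretely. The following minimal example (with hypothetical stand-in scores, not part of the full implementation later in this section) shows how a causal mask prevents each position from attending to later timesteps:

```python
import torch

# Causal (subsequent-position) mask: position i may only attend to j <= i.
num_steps = 4
scores = torch.zeros(num_steps, num_steps)  # stand-in attention scores
# True above the diagonal marks the "future" positions to be masked out
mask = torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool), diagonal=1)
# Setting masked scores to -inf zeroes them out after the softmax
weights = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
print(weights)
```

With uniform scores, row `i` spreads its weight evenly over positions `0..i` and puts exactly zero weight on later positions, which is the autoregressive property the decoder needs.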
Positional Encoding
- Adds information about the position of tokens in the sequence, compensating for the absence of recurrence in the model.
- Uses sine and cosine functions of different frequencies.
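A minimal sketch of such an encoding, assuming the standard sinusoidal formulation from Vaswani et al. (2017); this helper is illustrative, and the example code later in this section uses `d2l.PositionalEncoding` instead:

```python
import torch

def positional_encoding(num_steps, num_hiddens):
    # P[pos, 2j]   = sin(pos / 10000^(2j/d))
    # P[pos, 2j+1] = cos(pos / 10000^(2j/d))
    P = torch.zeros(num_steps, num_hiddens)
    pos = torch.arange(num_steps, dtype=torch.float32).unsqueeze(1)
    div = torch.pow(10000.0,
                    torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
    P[:, 0::2] = torch.sin(pos / div)
    P[:, 1::2] = torch.cos(pos / div)
    return P

P = positional_encoding(num_steps=50, num_hiddens=32)
```

Each column is a sinusoid of a different frequency, so every position receives a unique pattern, and relative offsets correspond to fixed linear transformations of the encoding.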
Encoder and Decoder Layers
- Both the encoder and decoder are composed of a stack of layers, each containing two sub-layers: a multi-head self-attention mechanism and a position-wise fully connected feed-forward network.
- A residual connection followed by layer normalization is employed around each of the two sub-layers within a single layer.
Encoder-Decoder Attention
- In the decoder, an additional encoder-decoder attention layer helps the decoder focus on appropriate parts of the input sequence.
- Queries come from the previous decoder layer, and the keys and values come from the output of the encoder.
Positionwise Feed-Forward Networks
- Applies a fully connected feed-forward network to each position separately and identically. This consists of two linear transformations with a ReLU activation in between.
Residual Connection and Layer Normalization
- The output of each sublayer in the encoder and decoder is added to that sublayer's input (residual connection) and the sum is then normalized (layer normalization).
Usage in Machine Translation
- The Transformer model has shown significant success in machine translation, owing to its parallelizable training and its ability to capture long-range dependencies between source and target tokens.
- Typically trained using a paired dataset of source and target sentences.
Mathematical Foundations
- Attention function (scaled dot-product attention):

  $$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

  where $Q$, $K$, and $V$ are the queries, keys, and values respectively, and $d_k$ is the dimension of the keys.
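The scaled dot-product attention formula can be transcribed almost directly into PyTorch. This is a standalone sketch for illustration; the full implementation later in this section uses `d2l.DotProductAttention` instead:

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    # softmax(Q K^T / sqrt(d_k)) V
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, num_q, num_kv)
    return torch.softmax(scores, dim=-1) @ V

Q = torch.ones(2, 1, 4)   # (batch, num_queries, d_k); identical queries
K = torch.ones(2, 6, 4)   # identical keys -> uniform attention weights
V = torch.arange(48, dtype=torch.float32).reshape(2, 6, 4)
out = scaled_dot_product_attention(Q, K, V)
```

Because all keys here are identical, the softmax produces uniform weights and each output row is simply the mean of the value rows, which makes the behavior easy to verify by hand.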
- The input embeddings are multiplied by $\sqrt{d_{\text{model}}}$ before the positional encodings are added, and each sub-layer output is layer-normalized; both choices help stabilize training dynamics, especially in deeper models.
Challenges and Innovations
- Dealing with long input sequences can be computationally expensive due to the quadratic complexity of the self-attention mechanism with respect to sequence length.
- Variants and improvements such as "Efficient Transformers" address this by approximating the attention mechanism or sparsifying the connections.
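As an illustration of one such idea, here is a hypothetical sketch of local (windowed) attention; the function name and the per-position loop are illustrative, not drawn from any specific library:

```python
import math
import torch

def local_attention(Q, K, V, window):
    # Each query attends only to positions within `window` steps of itself,
    # reducing cost from O(n^2) to O(n * window)
    n, d = Q.shape
    out = torch.empty_like(Q)
    for i in range(n):
        lo, hi = max(0, i - window), min(n, i + window + 1)
        scores = Q[i] @ K[lo:hi].T / math.sqrt(d)
        out[i] = torch.softmax(scores, dim=-1) @ V[lo:hi]
    return out

Q, K, V = torch.randn(8, 4), torch.randn(8, 4), torch.randn(8, 4)
narrow = local_attention(Q, K, V, window=2)
# With a window covering the whole sequence, this reduces to full attention
full = torch.softmax(Q @ K.T / math.sqrt(4), dim=-1) @ V
wide = local_attention(Q, K, V, window=8)
```

Efficient-Transformer variants typically implement such sparsity with batched tensor operations rather than a Python loop; the loop here is only to make the windowing explicit.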
Example
Here is a complete PyTorch example implementing the Transformer model in the d2l style, covering both the encoder and the decoder components, followed by model creation and training.
import math
import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l
class PositionWiseFFN(nn.Module):
    """The positionwise feed-forward network."""
    def __init__(self, num_hiddens, ffn_num_hiddens):
        super().__init__()
        # Two linear transformations with a ReLU in between, applied
        # identically at every position
        self.dense1 = nn.Linear(num_hiddens, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, num_hiddens)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization."""
    def __init__(self, normalized_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)
class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    # `num_hiddens` is the total dimensionality across all heads; `num_heads` is h
    def __init__(self, num_hiddens, num_heads, dropout, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def transpose_qkv(self, X):
        # (batch_size, num_steps, num_hiddens) ->
        # (batch_size * num_heads, num_steps, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        X = X.permute(0, 2, 1, 3)
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        # Reverse of transpose_qkv: concatenate the heads back into num_hiddens
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (batch_size, num_queries or num_key-value pairs, num_hiddens)
        # Shape of `valid_lens`: (batch_size,) or (batch_size, num_queries)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))
        if valid_lens is not None:
            # Copy each example's valid length once per head
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        output = self.attention(queries, keys, values, valid_lens)
        return self.W_o(self.transpose_output(output))
class TransformerEncoderBlock(nn.Module):
    """Transformer encoder block."""
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
class TransformerEncoder(nn.Module):
    """Transformer encoder."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_layers, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.ModuleList([
            TransformerEncoderBlock(num_hiddens, ffn_num_hiddens, num_heads, dropout)
            for _ in range(num_layers)])

    def forward(self, X, valid_lens, *args):
        # Scale the embeddings by sqrt(num_hiddens) so they are on the same
        # scale as the positional encodings before the two are summed
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X
class TransformerDecoderBlock(nn.Module):
    """The i-th Transformer decoder block."""
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i,
                 use_bias=False):
        super().__init__()
        self.i = i
        self.attention1 = MultiHeadAttention(num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.attention2 = MultiHeadAttention(num_hiddens, num_heads, dropout, use_bias)
        self.addnorm2 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(num_hiddens, ffn_num_hiddens)
        self.addnorm3 = AddNorm(num_hiddens, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During prediction, tokens are decoded one timestep at a time, so
        # state[2][self.i] caches the representations of all decoded timesteps
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Every row of dec_valid_lens is [1, 2, ..., num_steps], so that
            # position t attends only to positions up to t (masked self-attention)
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None
        # Masked self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention: queries come from the decoder, while keys
        # and values come from the encoder outputs
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
class TransformerDecoder(nn.Module):
    """Transformer decoder."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_layers, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.ModuleList()
        for i in range(num_layers):
            self.blks.append(TransformerDecoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        for blk in self.blks:
            X, state = blk(X, state)
        return self.dense(X), state
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
ffn_num_hiddens, num_heads = 64, 4
device = d2l.try_gpu()
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = TransformerEncoder(
    len(src_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers, dropout)
model = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(model, train_iter, lr=0.005, num_epochs=50,
                  tgt_vocab=tgt_vocab, device=device)